import logging
import pickle

import cv2
import numpy as np
from flask import jsonify, request
from werkzeug.exceptions import BadRequest

from app.model.models import User, Photo, FaceModel
from app.recognition import FaceNetTorch, AntiSpoofingTwoStreamVit
from app.recognition import compare
from app.utils import file_utils
from app.utils import retinex
from . import bp

# Module-level logger.
LOG = logging.getLogger(__name__)

# Model singletons obtained once at import time and shared by every request
# served by the handlers in this module. Instantiation order is kept as-is in
# case .instance() has side effects (presumably loading model weights — TODO confirm).
facenet = FaceNetTorch.instance()
anti_spoofing_two_stream_vit = AntiSpoofingTwoStreamVit.instance()


def _placeholder_user(username):
    """Build an unsaved ``User`` (id 0, empty phone) used to report a non-match."""
    user = User()
    user.from_dict({'id': 0, 'username': username, 'cell_phone_number': ''})
    return user


def _is_spoof(photo):
    """Return True when the anti-spoofing model predicts class 1 (not a live face).

    The photo is resized to 224, converted to grayscale (kept as 3 channels),
    enhanced with MSRCR retinex and reduced back to one channel; the resized
    colour image and the MSR image are fed to the two-stream ViT together.
    """
    image = file_utils.resize_blob_to(blob=photo, size=224)
    image_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # Re-expand to 3 channels: MSRCR is fed a 3-channel image here.
    image_gray = cv2.cvtColor(image_gray, cv2.COLOR_GRAY2BGR)
    image_msr = retinex.MSRCR(image_gray, sigma_list=[15, 80, 250], G=5.0, b=25.0, alpha=125.0,
                              beta=46.0, low_clip=0.01, high_clip=0.99)
    image_msr = cv2.cvtColor(image_msr, cv2.COLOR_BGR2GRAY)
    return anti_spoofing_two_stream_vit.get_prediction(x=image, msr=image_msr) == 1


def _load_gallery():
    """Return ``(embeddings, user_ids)`` numpy arrays built from every ``Photo`` row."""
    photos = Photo.query.all()
    # NOTE(review): embeddings are unpickled from the database; this is only safe
    # as long as that column is written exclusively by this application.
    embeddings = np.array([pickle.loads(p.embedding) for p in photos])
    user_ids = np.array([p.user_id for p in photos])
    return embeddings, user_ids


@bp.route('/face/recognition', methods=['POST'])
def face_recognition():
    """Identify the user shown in a posted photo.

    Expects a JSON object with a ``photo`` field (handed to
    ``file_utils.resize_blob_to``) and an optional ``turn_on_anti_spoofing``
    flag (default False).

    Returns:
        JSON envelope whose ``data`` is the matched user's dict, or a
        placeholder user with id 0 named '非活体' when anti-spoofing flags the
        photo, or '无法识别' when no enrolled user matches (including when the
        matched user id no longer exists).

    Raises:
        BadRequest: if the body is not a JSON object or lacks ``photo``.
    """
    data = request.json or {}
    # A JSON array/number/string would make the membership test below misbehave.
    if not isinstance(data, dict) or 'photo' not in data:
        raise BadRequest('Must include photo.')
    photo = data['photo']

    # face anti-spoofing
    if data.get('turn_on_anti_spoofing', False) and _is_spoof(photo):
        user = _placeholder_user('非活体')
        return jsonify(success=True, data=user.to_dict(), status_code=200)

    # face recognition
    embedding = facenet.get_embeddings(file_utils.resize_blob_to(photo))
    x, y = _load_gallery()
    user = None
    # With nothing enrolled, skip the comparison instead of passing empty arrays on.
    if len(y) > 0:
        target = compare.face_net_compare(embedding, x, y)
        if target != 0:
            # May be None if the photo's owner was deleted; handled below.
            user = User.query.filter_by(id=int(target)).first()
    if user is None:
        user = _placeholder_user('无法识别')
    return jsonify(success=True, data=user.to_dict(), status_code=200)


# Rule starts with '/' like '/face/recognition'; Werkzeug rejects rules without a
# leading slash unless the blueprint supplies a url_prefix (which normalizes it).
@bp.route('/face/models/<int:task_id>', methods=['GET'])
def get_face_models(task_id):
    """List the face models belonging to a task as label/value pairs.

    Args:
        task_id: task id from the URL (Werkzeug's ``int`` converter only
            matches non-negative integers, so 0 is the only falsy value).

    Returns:
        JSON envelope whose ``data`` is a list of
        ``{'label': model_name, 'value': id}`` dicts, one per ``FaceModel``
        with that ``task_id`` (empty list if there are none).

    Raises:
        BadRequest: if ``task_id`` is 0.
    """
    if not task_id:
        raise BadRequest('Must include a valid task_id.')
    face_models = FaceModel.query.filter_by(task_id=task_id).all()
    result = [{'label': face_model.model_name, 'value': face_model.id} for face_model in face_models]
    return jsonify(success=True, data=result, status_code=200)
